package cn.doitedu.ml.demo

import cn.doitedu.commons.util.SparkUtil
import org.apache.log4j.{Level, Logger}
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.expressions.{UserDefinedFunction, Window}
import org.apache.spark.sql.types.{DataTypes, StructType}

import scala.collection.mutable

/**
  * @date: 2020/2/19
  * @site: www.doitedu.cn
  * @author: hunter.d 涛哥
  * @qq: 657270652
  * @description:
  * KNN audience classification demo.
  *
  * Pipeline:
  *   1. load labelled sample vectors and unlabelled ("to classify") vectors from csv
  *   2. cross-join the two sets and compute the Euclidean distance of every pair
  *   3. for each unlabelled record keep its k nearest samples (default k = 5)
  *   4. assign each record the label holding the majority among those k neighbours
  */
object KNNClassify {

  def main(args: Array[String]): Unit = {

    // Optional first CLI argument overrides K (number of neighbours); defaults
    // to 5, preserving the original behaviour when no argument is given.
    // K should be odd for a binary-label data set so the majority vote
    // (count > k/2) can never tie.
    val k = if (args.nonEmpty) args(0).toInt else 5

    // cn.doitedu.ml.demo.KNNClassify
    val logger = Logger.getLogger(this.getClass.getName)
    logger.setLevel(Level.DEBUG)

    // Silence Spark's own logging; all Spark loggers live under org.apache.spark.*
    Logger.getLogger("org.apache.spark").setLevel(Level.WARN)

    // Build the SparkSession
    val spark = SparkUtil.getSparkSession(this.getClass.getSimpleName)
    import spark.implicits._
    logger.debug("sparksession 构建完成")

    // Load the labelled sample file.
    // Spark could infer the csv schema via .option("inferSchema","true"), but
    // inference launches an extra job that scans the whole data set, which is
    // very costly — so the schema is declared explicitly instead.
    val schema1 = new StructType()
      .add("label", DataTypes.DoubleType)
      .add("f1", DataTypes.DoubleType)
      .add("f2", DataTypes.DoubleType)
      .add("f3", DataTypes.DoubleType)
      .add("f4", DataTypes.DoubleType)
      .add("f5", DataTypes.DoubleType)

    val sampleDF = spark.read.option("header", "true").schema(schema1)/*.option("inferSchema","true")*/.csv("userprofile/data/demo/knn/sample")
    logger.debug("样本数据加载完成,分区数为： " + sampleDF.rdd.partitions.size)

    // Load the unlabelled data that we want to classify
    val schema2 = new StructType()
      .add("id", DataTypes.DoubleType)
      .add("f1", DataTypes.DoubleType)
      .add("f2", DataTypes.DoubleType)
      .add("f3", DataTypes.DoubleType)
      .add("f4", DataTypes.DoubleType)
      .add("f5", DataTypes.DoubleType)

    val toClassifyDF = spark.read.option("header", "true").schema(schema2).csv("userprofile/data/demo/knn/to_classify")
    logger.debug("待分类数据加载完成,分区数为: " + toClassifyDF.rdd.partitions.size)

    // Pair every unlabelled record with every sample: Cartesian product (cross join).
    // The feature columns are renamed b1..b5 first so they do not collide with
    // the sample's f1..f5 after the join.
    val joinedDF = (toClassifyDF.toDF("id", "b1", "b2", "b3", "b4", "b5")).crossJoin(sampleDF)
    logger.debug("待分类数据 cross join  样本数据完成,分区数为: " + joinedDF.rdd.partitions.size)
    joinedDF.show(100, false)

    /**
      * +-----+---+---+---+---+---+---+---+---+---+---+---+
      * |label|f1 |f2 |f3 |f4 |f5 |id |b1 |b2 |b3 |b4 |b5 |
      * +-----+---+---+---+---+---+---+---+---+---+---+---+
      * |0    |10 |20 |30 |40 |30 |1  |11 |21 |31 |44 |32 |
      * |0    |10 |20 |30 |40 |30 |2  |14 |26 |32 |39 |30 |
      * |0    |10 |20 |30 |40 |30 |3  |32 |14 |21 |42 |32 |
      * |0    |10 |20 |30 |40 |30 |4  |34 |12 |22 |42 |34 |
      * |0    |10 |20 |30 |40 |30 |5  |34 |12 |22 |42 |34 |
      *
      */

    // UDF computing the Euclidean distance between two feature arrays.
    // NOTE: Vectors.sqdist returns the SQUARED distance, so take the square
    // root to get a true Euclidean distance. The neighbour ranking (and hence
    // the final classification) is identical either way, because sqrt is
    // monotonically increasing.
    import org.apache.spark.sql.functions._
    val eudi: UserDefinedFunction = udf(
      (arr1: mutable.WrappedArray[Double], arr2: mutable.WrappedArray[Double]) => {
        val v1 = Vectors.dense(arr1.toArray)
        val v2 = Vectors.dense(arr2.toArray)
        math.sqrt(Vectors.sqdist(v1, v2))
      }
    )

    // Distance between each sample vector and each unknown vector
    val dist = joinedDF.select(
      'label,
      'id,
      eudi(array('f1, 'f2, 'f3, 'f4, 'f5), array('b1, 'b2, 'b3, 'b4, 'b5)) as "dist"
    )
    /**
      * +-----+---+-----+
      * |label|id |dist |
      * +-----+---+-----+
      * |0.0  |1.0|23.0 |
      * |0.0  |1.0|19.0 |
      * |0.0  |1.0|20.0 |
      * |...............|
      * +-----+---+-----+
      */

    // For each unknown record: the labels of its k nearest samples.
    // row_number over (partition by id, order by dist) ranks every sample by
    // proximity per record; keeping rn <= k selects the k closest ones.
    dist.createTempView("dist")
    val topK = spark.sql(
      s"""
        |select
        |label,
        |id
        |from
        |(
        |   select
        |   label,
        |   id,
        |   dist,
        |   row_number() over(partition by id order by dist) as rn
        |   from dist
        |)
        |where rn <= $k
        |
      """.stripMargin)

    /**
      * Labels of the k nearest samples per record, e.g. (k = 5):
      * +-----+---+
      * |label|id |
      * +-----+---+
      * |0.0  |1.0|
      * |0.0  |1.0|
      * |0.0  |1.0|
      * |1.0  |1.0|
      * |0.0  |1.0|
      * ..........
      * +-----+---+
      */

    // Majority vote: a label wins for a record when it appears on more than
    // half of the k neighbours (count > k/2 — with k=5 that is count > 2,
    // matching the original hard-coded threshold).
    topK.createTempView("topk")
    spark.sql(
      s"""
        |
        |select
        |id,label
        |from topk
        |group by id,label
        |having count(1) > ${k / 2}
        |
      """.stripMargin)
      .show(100, false)

    spark.close()
  }

}
